# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Filters predictions based on xattn model.

The script takes as input:
- Predictions jsonl file from an initial stage model (e.g. BM25, dual encoder).
- Output of scoring the candidate (query, document) pairs with the xattn model.
- Metadata generated by `gen_inference_inputs.py`.

The script outputs a new jsonl examples file containing only the document
titles classified as relevant by the xattn model.
"""

import math

from absl import app
from absl import flags

from language.quest.common import example_utils
from language.quest.common import jsonl_utils
from language.quest.common import tsv_utils
from language.quest.xattn import xattn_utils

FLAGS = flags.FLAGS

# --- Input files ---
flags.DEFINE_string(
    name="predictions",
    default="",
    help="Examples jsonl file with predicted scores.")
flags.DEFINE_string(
    name="xattn_outputs",
    default="",
    help="Predictions from xattn model.")
flags.DEFINE_string(
    name="metadata",
    default="",
    help="Input tsv file of metadata for xattn predictions.")

# --- Output file ---
flags.DEFINE_string(
    name="output",
    default="",
    help="Output examples jsonl file.")

# --- Filtering ---
# A (query, doc_title) pair is kept only when exp(score) of its xattn output is
# strictly greater than this value.
flags.DEFINE_float(
    name="threshold",
    default=0.9,
    help="Relevance threshold.")


def _filter_docs(example, relevant_query_docs):
  doc_titles = []
  for doc_title in example.docs:
    if (example.query, doc_title) in relevant_query_docs:
      doc_titles.append(doc_title)
  return doc_titles


def _get_relevant_query_docs(metadata, xattn_outputs, threshold):
  """Returns the (query, doc_title) pairs the xattn model scores above threshold.

  Args:
    metadata: Sequence of (query, doc_title) rows, aligned row-for-row with
      `xattn_outputs`.
    xattn_outputs: Sequence of xattn output dicts. Each is expected to have
      `["inputs"]["targets_pretokenized"]` equal to `xattn_utils.POS_LABEL` and
      a numeric `"score"` field.
    threshold: A pair is kept when `exp(score)` is strictly greater than this.

  Returns:
    Set of (query, doc_title) tuples.

  Raises:
    ValueError: If `metadata` and `xattn_outputs` differ in length, or if an
      output was scored against a label other than `xattn_utils.POS_LABEL`.
  """
  # A length mismatch means the two files do not correspond row-for-row;
  # `zip` would silently truncate instead of reporting it.
  if len(metadata) != len(xattn_outputs):
    raise ValueError(
        "Metadata has %d rows but xattn outputs has %d rows." %
        (len(metadata), len(xattn_outputs)))

  relevant_query_docs = set()
  for (query, doc_title), output in zip(metadata, xattn_outputs):
    label = output["inputs"]["targets_pretokenized"]
    # NOTE(review): scores are presumably only interpretable as relevance when
    # computed against the positive label, so anything else is rejected.
    if label != xattn_utils.POS_LABEL:
      raise ValueError("Unexpected label: %s" % label)
    # The score is exponentiated, i.e. treated as a log-probability.
    prob = math.exp(output["score"])
    if prob > threshold:
      relevant_query_docs.add((query, doc_title))
  return relevant_query_docs


def main(unused_argv):
  """Writes the initial predictions restricted to xattn-relevant documents.

  Reads the xattn metadata/outputs and initial predictions named by FLAGS, and
  writes the filtered examples to `FLAGS.output`.
  """
  # Materialize both inputs so they can be length-checked before pairing.
  metadata = list(tsv_utils.read_tsv(FLAGS.metadata))
  xattn_outputs = list(jsonl_utils.read(FLAGS.xattn_outputs))
  relevant_query_docs = _get_relevant_query_docs(
      metadata, xattn_outputs, FLAGS.threshold)

  # Filter docs in the initial predictions.
  init_predictions = example_utils.read_examples(FLAGS.predictions)
  new_predictions = []
  for init_prediction in init_predictions:
    new_docs = _filter_docs(init_prediction, relevant_query_docs)
    # Copy over the original example but filter the docs. The original scores
    # are dropped. Note that some metadata fields may now reference docs that
    # have been removed.
    new_prediction = example_utils.Example(
        query=init_prediction.query,
        docs=new_docs,
        scores=None,
        metadata=init_prediction.metadata)
    new_predictions.append(new_prediction)

  # Write predictions with the filtered document set.
  example_utils.write_examples(FLAGS.output, new_predictions)


if __name__ == "__main__":
  app.run(main)
